/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.flink.test.typeserializerupgrade;

import org.apache.flink.api.common.ExecutionConfig;
import org.apache.flink.api.common.functions.ReduceFunction;
import org.apache.flink.api.common.functions.RichMapFunction;
import org.apache.flink.api.common.state.ListState;
import org.apache.flink.api.common.state.ListStateDescriptor;
import org.apache.flink.api.common.state.ReducingState;
import org.apache.flink.api.common.state.ReducingStateDescriptor;
import org.apache.flink.api.common.state.ValueState;
import org.apache.flink.api.common.state.ValueStateDescriptor;
import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
import org.apache.flink.api.common.typeutils.base.LongSerializer;
import org.apache.flink.api.java.functions.KeySelector;
import org.apache.flink.configuration.CheckpointingOptions;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.configuration.StateBackendOptions;
import org.apache.flink.core.testutils.CommonTestUtils;
import org.apache.flink.runtime.checkpoint.OperatorSubtaskState;
import org.apache.flink.runtime.operators.testutils.MockEnvironment;
import org.apache.flink.runtime.operators.testutils.MockEnvironmentBuilder;
import org.apache.flink.runtime.operators.testutils.MockInputSplitProvider;
import org.apache.flink.runtime.state.FunctionInitializationContext;
import org.apache.flink.runtime.state.FunctionSnapshotContext;
import org.apache.flink.runtime.state.StateBackend;
import org.apache.flink.runtime.state.StateBackendLoader;
import org.apache.flink.streaming.api.checkpoint.CheckpointedFunction;
import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
import org.apache.flink.streaming.api.operators.StreamMap;
import org.apache.flink.streaming.util.KeyedOneInputStreamOperatorTestHarness;
import org.apache.flink.streaming.util.OneInputStreamOperatorTestHarness;
import org.apache.flink.util.DynamicCodeLoadingException;
import org.apache.flink.util.IOUtils;
import org.apache.flink.util.StateMigrationException;
import org.apache.flink.util.TestLogger;

import org.junit.ClassRule;
import org.junit.Test;
import org.junit.rules.TemporaryFolder;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;

import javax.tools.JavaCompiler;
import javax.tools.ToolProvider;

import java.io.File;
import java.io.FileWriter;
import java.io.IOException;
import java.lang.reflect.Field;
import java.net.URL;
import java.net.URLClassLoader;
import java.util.Arrays;
import java.util.Collection;
import java.util.Iterator;

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;

/**
 * Tests the state migration behaviour when the underlying POJO type changes and one tries to
 * recover from old state.
 *
 * <p>Each test compiles two versions of a class named {@code Pojo} at runtime into separate
 * temporary directories. It runs a stateful operator whose user-code class loader sees the first
 * version and snapshots its state. It then restores that snapshot into an operator whose user-code
 * class loader sees the second version, which verifies the restored state. Every test runs once
 * per state backend returned by {@link #parameters()}.
 */
@RunWith(Parameterized.class)
public class PojoSerializerUpgradeTest extends TestLogger {

    @Parameterized.Parameters(name = "StateBackend: {0}")
    public static Collection<String> parameters() {
        return Arrays.asList(
                StateBackendLoader.HASHMAP_STATE_BACKEND_NAME,
                StateBackendLoader.ROCKSDB_STATE_BACKEND_NAME);
    }

    @ClassRule public static TemporaryFolder temporaryFolder = new TemporaryFolder();

    private final StateBackend stateBackend;

    /**
     * Loads the state backend under test from a configuration.
     *
     * @param backendType value for {@link StateBackendOptions#STATE_BACKEND}
     */
    public PojoSerializerUpgradeTest(String backendType)
            throws IOException, DynamicCodeLoadingException {
        Configuration config = new Configuration();
        config.set(StateBackendOptions.STATE_BACKEND, backendType);
        config.set(
                CheckpointingOptions.CHECKPOINTS_DIRECTORY,
                temporaryFolder.newFolder().toURI().toString());
        stateBackend =
                StateBackendLoader.loadStateBackendFromConfig(
                        config, Thread.currentThread().getContextClassLoader(), null);
    }

    private static final String POJO_NAME = "Pojo";

    // original version of the POJO that writes the state
    private static final String SOURCE_A =
            "import java.util.Objects;"
                    + "public class Pojo { "
                    + "private long a; "
                    + "private String b; "
                    + "public long getA() { return a;} "
                    + "public void setA(long value) { a = value; }"
                    + "public String getB() { return b; }"
                    + "public void setB(String value) { b = value; }"
                    + "@Override public boolean equals(Object obj) { if (obj instanceof Pojo) { Pojo other = (Pojo) obj; return a == other.a && b.equals(other.b);} else { return false; }}"
                    + "@Override public int hashCode() { return Objects.hash(a, b); } "
                    + "@Override public String toString() {return \"(\" + a + \", \" + b + \")\";}}";

    // changed order of fields which should be recoverable
    private static final String SOURCE_B =
            "import java.util.Objects;"
                    + "public class Pojo { "
                    + "private String b; "
                    + "private long a; "
                    + "public long getA() { return a;} "
                    + "public void setA(long value) { a = value; }"
                    + "public String getB() { return b; }"
                    + "public void setB(String value) { b = value; }"
                    + "@Override public boolean equals(Object obj) { if (obj instanceof Pojo) { Pojo other = (Pojo) obj; return a == other.a && b.equals(other.b);} else { return false; }}"
                    + "@Override public int hashCode() { return Objects.hash(a, b); } "
                    + "@Override public String toString() {return \"(\" + a + \", \" + b + \")\";}}";

    // changed type of a field which should not be recoverable
    private static final String SOURCE_C =
            "import java.util.Objects;"
                    + "public class Pojo { "
                    + "private double a; "
                    + "private String b; "
                    + "public double getA() { return a;} "
                    + "public void setA(double value) { a = value; }"
                    + "public String getB() { return b; }"
                    + "public void setB(String value) { b = value; }"
                    + "@Override public boolean equals(Object obj) { if (obj instanceof Pojo) { Pojo other = (Pojo) obj; return a == other.a && b.equals(other.b);} else { return false; }}"
                    + "@Override public int hashCode() { return Objects.hash(a, b); } "
                    + "@Override public String toString() {return \"(\" + a + \", \" + b + \")\";}}";

    // additional field which should be recoverable (the restore is expected to succeed)
    private static final String SOURCE_D =
            "import java.util.Objects;"
                    + "public class Pojo { "
                    + "private long a; "
                    + "private String b; "
                    + "private double c; "
                    + "public long getA() { return a;} "
                    + "public void setA(long value) { a = value; }"
                    + "public String getB() { return b; }"
                    + "public void setB(String value) { b = value; }"
                    + "public double getC() { return c; } "
                    + "public void setC(double value) { c = value; }"
                    + "@Override public boolean equals(Object obj) { if (obj instanceof Pojo) { Pojo other = (Pojo) obj; return a == other.a && b.equals(other.b) && c == other.c;} else { return false; }}"
                    + "@Override public int hashCode() { return Objects.hash(a, b, c); } "
                    + "@Override public String toString() {return \"(\" + a + \", \" + b + \", \" + c + \")\";}}";

    // missing field which should be recoverable (the restore is expected to succeed)
    private static final String SOURCE_E =
            "import java.util.Objects;"
                    + "public class Pojo { "
                    + "private long a; "
                    + "public long getA() { return a;} "
                    + "public void setA(long value) { a = value; }"
                    + "@Override public boolean equals(Object obj) { if (obj instanceof Pojo) { Pojo other = (Pojo) obj; return a == other.a;} else { return false; }}"
                    + "@Override public int hashCode() { return Objects.hash(a); } "
                    + "@Override public String toString() {return \"(\" + a + \")\";}}";

    /** We should be able to handle a changed field order of a POJO as keyed state. */
    @Test
    public void testChangedFieldOrderWithKeyedState() throws Exception {
        testPojoSerializerUpgrade(SOURCE_A, SOURCE_B, true, true);
    }

    /** We should be able to handle a changed field order of a POJO as operator state. */
    @Test
    public void testChangedFieldOrderWithOperatorState() throws Exception {
        testPojoSerializerUpgrade(SOURCE_A, SOURCE_B, true, false);
    }

    /**
     * Changing field types of a POJO as keyed state should require a state migration, so the
     * restore must fail with a {@link StateMigrationException}.
     */
    @Test
    public void testChangedFieldTypesWithKeyedState() throws Exception {
        assertUpgradeFailsWithStateMigration(SOURCE_A, SOURCE_C, true, true);
    }

    /**
     * Changing field types of a POJO as operator state should require a state migration, so the
     * restore must fail with a {@link StateMigrationException}.
     */
    @Test
    public void testChangedFieldTypesWithOperatorState() throws Exception {
        assertUpgradeFailsWithStateMigration(SOURCE_A, SOURCE_C, true, false);
    }

    /** Adding fields to a POJO as keyed state should succeed. */
    @Test
    public void testAdditionalFieldWithKeyedState() throws Exception {
        testPojoSerializerUpgrade(SOURCE_A, SOURCE_D, true, true);
    }

    /** Adding fields to a POJO as operator state should succeed. */
    @Test
    public void testAdditionalFieldWithOperatorState() throws Exception {
        testPojoSerializerUpgrade(SOURCE_A, SOURCE_D, true, false);
    }

    /** Removing fields from a POJO as keyed state should succeed. */
    @Test
    public void testMissingFieldWithKeyedState() throws Exception {
        testPojoSerializerUpgrade(SOURCE_A, SOURCE_E, false, true);
    }

    /** Removing fields from a POJO as operator state should succeed. */
    @Test
    public void testMissingFieldWithOperatorState() throws Exception {
        testPojoSerializerUpgrade(SOURCE_A, SOURCE_E, false, false);
    }

    /**
     * Runs {@link #testPojoSerializerUpgrade} and asserts that it fails with an exception whose
     * cause chain contains a {@link StateMigrationException}. Any other exception is rethrown.
     */
    private void assertUpgradeFailsWithStateMigration(
            String classSourceA, String classSourceB, boolean hasBField, boolean isKeyedState)
            throws Exception {
        try {
            testPojoSerializerUpgrade(classSourceA, classSourceB, hasBField, isKeyedState);
            // fail() throws an AssertionError, which is not caught by the Exception handler below
            fail("Expected a state migration exception.");
        } catch (Exception e) {
            if (!CommonTestUtils.containsCause(e, StateMigrationException.class)) {
                throw e;
            }
            // StateMigrationException expected
        }
    }

    /**
     * Writes state with POJO version A, then restores and verifies it with POJO version B.
     *
     * @param classSourceA source of the {@code Pojo} class used to write the state
     * @param classSourceB source of the {@code Pojo} class used to restore and verify the state
     * @param hasBField whether the mapper sets field {@code b} in both runs (and thus verifies it)
     * @param isKeyedState {@code true} to exercise keyed state, {@code false} for operator state
     */
    private void testPojoSerializerUpgrade(
            String classSourceA, String classSourceB, boolean hasBField, boolean isKeyedState)
            throws Exception {
        final Configuration taskConfiguration = new Configuration();
        final ExecutionConfig executionConfig = new ExecutionConfig();
        final KeySelector<Long, Long> keySelector = new IdentityKeySelector<>();
        final Collection<Long> inputs = Arrays.asList(1L, 2L, 45L, 67L, 1337L);

        // run the program with classSourceA and take a snapshot of the written state
        final OperatorSubtaskState stateHandles;
        try (URLClassLoader classLoaderA = compilePojo(classSourceA)) {
            stateHandles =
                    runOperator(
                            taskConfiguration,
                            executionConfig,
                            new StreamMap<>(new StatefulMapper(isKeyedState, false, hasBField)),
                            keySelector,
                            isKeyedState,
                            stateBackend,
                            classLoaderA,
                            null,
                            inputs);
        }

        // restore the snapshot with classSourceB; the mapper verifies the restored state
        try (URLClassLoader classLoaderB = compilePojo(classSourceB)) {
            runOperator(
                    taskConfiguration,
                    executionConfig,
                    new StreamMap<>(new StatefulMapper(isKeyedState, true, hasBField)),
                    keySelector,
                    isKeyedState,
                    stateBackend,
                    classLoaderB,
                    stateHandles,
                    inputs);
        }
    }

    /**
     * Writes {@code classSource} as {@code Pojo.java} into a fresh temporary folder, compiles it
     * there and returns a class loader rooted at that folder. The caller must close the loader.
     */
    private static URLClassLoader compilePojo(String classSource) throws IOException {
        File rootPath = temporaryFolder.newFolder();
        File sourceFile = writeSourceFile(rootPath, POJO_NAME + ".java", classSource);
        compileClass(sourceFile);

        return URLClassLoader.newInstance(
                new URL[] {rootPath.toURI().toURL()},
                Thread.currentThread().getContextClassLoader());
    }

    /**
     * Runs {@code operator} in a test harness, feeds it {@code input} and returns a snapshot.
     *
     * @param classLoader user-code class loader installed in the mock environment
     * @param operatorSubtaskState state to restore before processing, or {@code null} for none
     * @return the snapshot taken with checkpoint id 1 after all input has been processed
     */
    private OperatorSubtaskState runOperator(
            Configuration taskConfiguration,
            ExecutionConfig executionConfig,
            OneInputStreamOperator<Long, Long> operator,
            KeySelector<Long, Long> keySelector,
            boolean isKeyedState,
            StateBackend stateBackend,
            ClassLoader classLoader,
            OperatorSubtaskState operatorSubtaskState,
            Iterable<Long> input)
            throws Exception {

        try (final MockEnvironment environment =
                new MockEnvironmentBuilder()
                        .setTaskName("test task")
                        .setManagedMemorySize(32 * 1024)
                        .setInputSplitProvider(new MockInputSplitProvider())
                        .setBufferSize(256)
                        .setTaskConfiguration(taskConfiguration)
                        .setExecutionConfig(executionConfig)
                        .setMaxParallelism(16)
                        .setUserCodeClassLoader(classLoader)
                        .build()) {

            OneInputStreamOperatorTestHarness<Long, Long> harness = null;
            try {
                if (isKeyedState) {
                    harness =
                            new KeyedOneInputStreamOperatorTestHarness<>(
                                    operator,
                                    keySelector,
                                    BasicTypeInfo.LONG_TYPE_INFO,
                                    environment);
                } else {
                    harness =
                            new OneInputStreamOperatorTestHarness<>(
                                    operator, LongSerializer.INSTANCE, environment);
                }

                harness.setStateBackend(stateBackend);

                harness.setup();
                harness.initializeState(operatorSubtaskState);
                harness.open();

                long timestamp = 0L;

                for (Long value : input) {
                    harness.processElement(value, timestamp++);
                }

                long checkpointId = 1L;
                long checkpointTimestamp = timestamp + 1L;

                return harness.snapshot(checkpointId, checkpointTimestamp);
            } finally {
                IOUtils.closeQuietly(harness);
            }
        }
    }

    /** Writes {@code source} to {@code root/name}, creating parent directories as needed. */
    private static File writeSourceFile(File root, String name, String source) throws IOException {
        File sourceFile = new File(root, name);

        sourceFile.getParentFile().mkdirs();

        try (FileWriter writer = new FileWriter(sourceFile)) {
            writer.write(source);
        }

        return sourceFile;
    }

    /**
     * Compiles {@code sourceFile} with the system Java compiler (annotation processing disabled).
     * Without {@code -d}, javac writes the class file next to the source file.
     *
     * @return the compiler exit code, which is always {@code 0} because failures throw
     * @throws IllegalStateException if no system compiler is available or compilation fails
     */
    private static int compileClass(File sourceFile) {
        JavaCompiler compiler = ToolProvider.getSystemJavaCompiler();
        if (compiler == null) {
            // ToolProvider returns null when running on a JRE without javac
            throw new IllegalStateException(
                    "No system Java compiler available; this test requires a JDK.");
        }

        int exitCode = compiler.run(null, null, null, "-proc:none", sourceFile.getPath());
        if (exitCode != 0) {
            throw new IllegalStateException(
                    "Compilation of " + sourceFile + " failed with exit code " + exitCode);
        }
        return exitCode;
    }

    /**
     * Mapper that, for each input value, builds a {@code Pojo} (loaded reflectively from the
     * user-code class loader) with {@code a = value} and, if enabled, {@code b = value + ""}. In
     * write mode it stores the POJO in keyed or operator state; in verify mode it asserts that
     * the restored state contains an equal POJO.
     */
    private static final class StatefulMapper extends RichMapFunction<Long, Long>
            implements CheckpointedFunction {

        private static final long serialVersionUID = -520490739059396832L;

        private final boolean keyed;
        private final boolean verify;
        private final boolean hasBField;

        // keyed states
        private transient ValueState<Object> keyedValueState;
        private transient ListState<Object> keyedListState;
        private transient ReducingState<Object> keyedReducingState;

        // operator states
        private transient ListState<Object> partitionableListState;
        private transient ListState<Object> unionListState;

        private transient Class<?> pojoClass;
        private transient Field fieldA;
        private transient Field fieldB;

        /**
         * @param keyed use keyed state ({@code true}) or operator state ({@code false})
         * @param verify verify restored state ({@code true}) or write state ({@code false})
         * @param hasBField whether the POJO's field {@code b} is populated
         */
        StatefulMapper(boolean keyed, boolean verify, boolean hasBField) {
            this.keyed = keyed;
            this.verify = verify;
            this.hasBField = hasBField;
        }

        @Override
        public Long map(Long value) throws Exception {
            Object pojo = pojoClass.getDeclaredConstructor().newInstance();

            // Field.set unboxes the Long and widens it if the field is declared as double
            fieldA.set(pojo, value);

            if (hasBField) {
                fieldB.set(pojo, value + "");
            }

            if (verify) {
                if (keyed) {
                    assertEquals(pojo, keyedValueState.value());
                    assertTrue(containsElement(keyedListState.get(), pojo));
                    // FirstValueReducer keeps the first value added for the key
                    assertEquals(pojo, keyedReducingState.get());
                } else {
                    assertTrue(containsElement(partitionableListState.get(), pojo));
                    assertTrue(containsElement(unionListState.get(), pojo));
                }
            } else {
                if (keyed) {
                    keyedValueState.update(pojo);
                    keyedListState.add(pojo);
                    keyedReducingState.add(pojo);
                } else {
                    partitionableListState.add(pojo);
                    unionListState.add(pojo);
                }
            }

            return value;
        }

        /** Returns whether {@code elements} contains an element equal to {@code expected}. */
        private static boolean containsElement(Iterable<Object> elements, Object expected) {
            Iterator<Object> iterator = elements.iterator();
            while (iterator.hasNext()) {
                if (expected.equals(iterator.next())) {
                    return true;
                }
            }
            return false;
        }

        @Override
        public void snapshotState(FunctionSnapshotContext context) throws Exception {}

        @SuppressWarnings("unchecked")
        @Override
        public void initializeState(FunctionInitializationContext context) throws Exception {
            pojoClass = getRuntimeContext().getUserCodeClassLoader().loadClass(POJO_NAME);

            fieldA = pojoClass.getDeclaredField("a");
            fieldA.setAccessible(true);

            if (hasBField) {
                fieldB = pojoClass.getDeclaredField("b");
                fieldB.setAccessible(true);
            }

            if (keyed) {
                keyedValueState =
                        context.getKeyedStateStore()
                                .getState(
                                        new ValueStateDescriptor<>(
                                                "keyedValueState", (Class<Object>) pojoClass));
                keyedListState =
                        context.getKeyedStateStore()
                                .getListState(
                                        new ListStateDescriptor<>(
                                                "keyedListState", (Class<Object>) pojoClass));

                ReduceFunction<Object> reduceFunction = new FirstValueReducer<>();
                keyedReducingState =
                        context.getKeyedStateStore()
                                .getReducingState(
                                        new ReducingStateDescriptor<>(
                                                "keyedReducingState",
                                                reduceFunction,
                                                (Class<Object>) pojoClass));
            } else {
                partitionableListState =
                        context.getOperatorStateStore()
                                .getListState(
                                        new ListStateDescriptor<>(
                                                "partitionableListState",
                                                (Class<Object>) pojoClass));
                unionListState =
                        context.getOperatorStateStore()
                                .getUnionListState(
                                        new ListStateDescriptor<>(
                                                "unionListState", (Class<Object>) pojoClass));
            }
        }
    }

    /** Reducer that always keeps the first (accumulated) value and discards the new one. */
    private static final class FirstValueReducer<T> implements ReduceFunction<T> {

        private static final long serialVersionUID = -9222976423336835926L;

        @Override
        public T reduce(T value1, T value2) throws Exception {
            return value1;
        }
    }

    /** Key selector that uses each element as its own key. */
    private static final class IdentityKeySelector<T> implements KeySelector<T, T> {

        private static final long serialVersionUID = -3263628393881929147L;

        @Override
        public T getKey(T value) throws Exception {
            return value;
        }
    }
}
